#!/usr/bin/env python3

from http.server import SimpleHTTPRequestHandler, HTTPServer
from socketserver import ThreadingMixIn
import socket
import sys
import threading
import os
import traceback
import urllib.request
import subprocess


def is_ipv6(host):
    """Report whether *host* should be treated as an IPv6 address.

    The decision is made by exclusion: a string that ``socket.inet_aton``
    accepts as an IPv4 address yields ``False``; any string it rejects
    yields ``True``.

    Args:
        host: Address string to classify.

    Returns:
        ``False`` if *host* parses as an IPv4 address, ``True`` otherwise.
    """
    try:
        socket.inet_aton(host)
    except OSError:
        # inet_aton signals an invalid IPv4 string with OSError. Catching only
        # that keeps unrelated failures (KeyboardInterrupt, a non-string
        # argument) from being silently reported as "IPv6".
        return True
    return False


def get_local_port(host, ipv6):
    """Return a port number the OS assigned when binding *host* to port 0.

    Args:
        host: Local address to bind the probing socket to.
        ipv6: When true the socket is created with ``AF_INET6``, otherwise
            with ``AF_INET``.

    Returns:
        The port number reported by ``getsockname()``. The probing socket is
        closed before returning.
    """
    address_family = socket.AF_INET6 if ipv6 else socket.AF_INET

    with socket.socket(address_family) as probe:
        # Port 0 asks the OS to choose the port.
        probe.bind((host, 0))
        bound_address = probe.getsockname()
        return bound_address[1]


# Host and HTTP port of the ClickHouse server; both can be overridden through
# environment variables of the same name.
CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
CLICKHOUSE_PORT_HTTP = os.getenv("CLICKHOUSE_PORT_HTTP", "8123")

# Body the test HTTP server returns: a JSON document indented with tabs.
SERVER_JSON_RESPONSE = (
    "{\n"
    '\t"login": "ClickHouse",\n'
    '\t"id": 54801242,\n'
    '\t"name": "ClickHouse",\n'
    '\t"company": null\n'
    "}"
)

# Expected query output: the same document on one line, with every newline and
# tab written as the two-character escape sequences \n and \t.
EXPECTED_ANSWER = r'{\n\t"login": "ClickHouse",\n\t"id": 54801242,\n\t"name": "ClickHouse",\n\t"company": null\n}'

#####################################################################################
# This test starts an HTTP server and serves data to a ClickHouse url-engine based table.
# The objective of this test is to check that the ClickHouse server sends a User-Agent
# header with its HTTP requests.
# For the test to work, the IP address and port of the HTTP server (defined below) must
# be accessible from the ClickHouse server.
#####################################################################################

# IP address of this host that is accessible from the outside world.
# `hostname -i` may print several addresses; the first one is used.
HTTP_SERVER_HOST = subprocess.check_output(["hostname", "-i"]).decode("utf-8").split()[0]
IS_IPV6 = is_ipv6(HTTP_SERVER_HOST)
HTTP_SERVER_PORT = get_local_port(HTTP_SERVER_HOST, IS_IPV6)

# (host, port) pair of the HTTP server started from this script.
HTTP_SERVER_ADDRESS = (HTTP_SERVER_HOST, HTTP_SERVER_PORT)

# URL of that server; an IPv6 host is wrapped in square brackets.
HTTP_SERVER_URL_STR = (
    f"http://[{HTTP_SERVER_HOST}]:{HTTP_SERVER_PORT}/"
    if IS_IPV6
    else f"http://{HTTP_SERVER_HOST}:{HTTP_SERVER_PORT}/"
)


def get_ch_answer(query):
    """Send *query* to ClickHouse over HTTP and return the response text.

    The endpoint is taken from the ``CLICKHOUSE_URL`` environment variable
    when it is set; otherwise it is built from ``CLICKHOUSE_HOST`` and
    ``CLICKHOUSE_PORT_HTTP``.

    Args:
        query: Query text; it is encoded and sent as the request body.

    Returns:
        The response body decoded to ``str``.

    Raises:
        urllib.error.URLError: If the request fails (``HTTPError`` for an
            HTTP error status).
    """
    host = CLICKHOUSE_HOST
    # An IPv6 literal must be enclosed in square brackets inside a URL,
    # otherwise its colons are confused with the port separator. The decision
    # is based on the ClickHouse host itself, not on the address family of
    # the test HTTP server.
    if ":" in host and not host.startswith("["):
        host = f"[{host}]"

    url = os.environ.get("CLICKHOUSE_URL", f"http://{host}:{CLICKHOUSE_PORT_HTTP}")

    # The context manager closes the connection once the body has been read.
    with urllib.request.urlopen(url, data=query.encode()) as response:
        return response.read().decode()


def check_answers(query, answer):
    """Run *query* on ClickHouse and verify that the result equals *answer*.

    Both texts are compared after stripping leading and trailing whitespace.

    Args:
        query: Query text passed to ``get_ch_answer``.
        answer: Expected response text.

    Raises:
        Exception: If the fetched answer differs; the query, the expected
            answer and the fetched answer are printed to stderr first.
    """
    ch_answer = get_ch_answer(query)
    if ch_answer.strip() == answer.strip():
        return

    report = (
        ("FAIL on query:", query),
        ("Expected answer:", answer),
        ("Fetched answer :", ch_answer),
    )
    for label, value in report:
        print(label, value, file=sys.stderr)
    raise Exception("Fail on query")


# Server with check for User-Agent headers.
class HttpProcessor(SimpleHTTPRequestHandler):
    """Request handler that answers 200 only to clients identifying as ClickHouse."""

    def _set_headers(self):
        """Send the status line and response headers.

        The status is 200 when the request's ``User-Agent`` header starts
        with ``ClickHouse/`` and 403 otherwise (including when the header is
        missing or empty).
        """
        agent = self.headers.get("User-Agent")
        from_clickhouse = bool(agent) and agent.startswith("ClickHouse/")
        self.send_response(200 if from_clickhouse else 403)

        self.send_header("Content-Type", "text/csv")
        self.end_headers()

    def do_GET(self):
        """Answer a GET request; the body is written whatever the status is."""
        self._set_headers()
        payload = SERVER_JSON_RESPONSE.encode()
        self.wfile.write(payload)

    def log_message(self, format, *args):
        """Discard log messages instead of writing them out."""
        return None


class HTTPServerV6(HTTPServer):
    """``HTTPServer`` whose socket is created with the IPv6 address family."""

    address_family = socket.AF_INET6


class ThreadedHTTPServer(ThreadingMixIn, HTTPServer):
    """``HTTPServer`` combined with ``ThreadingMixIn``; adds nothing of its own."""


class ThreadedHTTPServerV6(ThreadingMixIn, HTTPServerV6):
    """``HTTPServerV6`` combined with ``ThreadingMixIn``; adds nothing of its own."""


def start_server(requests_amount):
    """Create the test HTTP server and a thread that serves a fixed number of requests.

    The server is bound to ``HTTP_SERVER_ADDRESS`` right away, using the IPv6
    server class when ``IS_IPV6`` is set. The returned thread is not started;
    the caller is expected to ``start()`` and ``join()`` it.

    Args:
        requests_amount: Number of requests to handle before the thread
            finishes.

    Returns:
        An unstarted ``threading.Thread`` that handles *requests_amount*
        requests and then closes the server.
    """
    server_class = ThreadedHTTPServerV6 if IS_IPV6 else ThreadedHTTPServer
    httpd = server_class(HTTP_SERVER_ADDRESS, HttpProcessor)

    def real_func():
        try:
            for _ in range(requests_amount):
                httpd.handle_request()
        finally:
            # Release the listening socket once serving is over (or has
            # failed); previously it was never closed.
            httpd.server_close()

    return threading.Thread(target=real_func)


#####################################################################
# Testing area.
#####################################################################


def test_select():
    """Select from the test HTTP server through ``url()`` and check the result.

    The query reads ``HTTP_SERVER_URL_STR`` with the ``JSONAsString`` format
    and the response is compared against ``EXPECTED_ANSWER``.
    """
    select_query = f"SELECT * FROM url('{HTTP_SERVER_URL_STR}','JSONAsString');"
    check_answers(select_query, EXPECTED_ANSWER)


def main():
    """Serve two requests in a background thread, run the check, report success."""
    # Two requests are served: HEAD + GET
    server_thread = start_server(2)
    server_thread.start()

    test_select()

    # Wait for the serving thread to finish before declaring success.
    server_thread.join()
    print("PASSED")


if __name__ == "__main__":
    try:
        main()
    except Exception as ex:
        # Report the failure on stderr: the traceback first, then the message.
        traceback.print_tb(ex.__traceback__, file=sys.stderr)
        print(ex, file=sys.stderr)
        # os._exit() terminates without flushing stdio buffers, so flush
        # stderr explicitly before calling it.
        sys.stderr.flush()

        os._exit(1)
